Decision Trees

Setup

Today’s data concerns strains of cannabis, each labeled with one of three types: sativa, indica, or hybrid:

# A tibble: 2,351 × 69
   Strain Type  Rating Effects Flavor Creative Energetic Tingly Euphoric Relaxed
   <chr>  <chr>  <dbl> <chr>   <chr>     <dbl>     <dbl>  <dbl>    <dbl>   <dbl>
 1 100-Og hybr…    4   Creati… Earth…        1         1      1        1       1
 2 98-Wh… hybr…    4.7 Relaxe… Flowe…        1         1      0        0       1
 3 1024   sati…    4.4 Uplift… Spicy…        1         1      0        0       1
 4 13-Da… hybr…    4.2 Tingly… Apric…        1         0      1        0       1
 5 24K-G… hybr…    4.6 Happy,… Citru…        0         0      0        1       1
 6 3-Bea… indi…    0   None    None          0         0      0        0       0
 7 3-Kin… hybr…    4.4 Relaxe… Earth…        0         0      0        1       1
 8 303-Og indi…    4.2 Relaxe… Citru…        0         0      0        1       1
 9 3D-Cbd sati…    4.6 Uplift… Earth…        0         0      0        0       1
10 3X-Cr… indi…    4.4 Relaxe… Earth…        0         0      1        1       1
# ℹ 2,341 more rows
# ℹ 59 more variables: Aroused <dbl>, Happy <dbl>, Uplifted <dbl>,
#   Hungry <dbl>, Talkative <dbl>, Giggly <dbl>, Focused <dbl>, Sleepy <dbl>,
#   Dry <dbl>, Mouth <dbl>, Earthy <dbl>, Sweet <dbl>, Citrus <dbl>,
#   Flowery <dbl>, Violet <dbl>, Diesel <dbl>, `Spicy/Herbal` <dbl>,
#   Sage <dbl>, Woody <dbl>, Apricot <dbl>, Grapefruit <dbl>, Orange <dbl>,
#   Pungent <dbl>, Grape <dbl>, Pine <dbl>, Skunk <dbl>, Berry <dbl>, …

Data Cleaning

cann_clean <- cann %>%
  mutate(
    Type = factor(Type)
  ) %>%
  drop_na(-Strain, -Effects, -Flavor)

Removing ALL Missing Values

Notice the - before the variables I do not want to remove missing values from: rows are dropped only when a missing value appears in one of the other columns.

Setting-up CV & Declaring a Recipe

cann_cvs <- vfold_cv(cann_clean, v = 5)

cann_recipe <- recipe(Type ~ ., 
                     data = cann_clean) %>%
  step_rm(Strain, Effects, Flavor)

You can use a . to stand in for every other variable in the dataset (everything except the outcome)!

You can remove variables you do not want to use as predictors using the step_rm() function.
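An alternative to step_rm() is to keep those columns in the data but change their role, so they are carried along without being used as predictors. A sketch of the same recipe using update_role():

```r
# Equivalent recipe: keep Strain, Effects, and Flavor in the data,
# but give them an "id" role so they are not treated as predictors
cann_recipe_alt <- recipe(Type ~ ., data = cann_clean) %>%
  update_role(Strain, Effects, Flavor, new_role = "id")
```

This is handy when you want identifiers like Strain available later for inspecting individual predictions.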

Previewing the Recipe

── Recipe ───────────────────────────────────────────────────────────────────────────────────────────────────

── Inputs 
Number of variables by role
outcome:    1
predictor: 68

── Operations 
• Variables removed: Strain, Effects, Flavor

Logistic Regression

logit_mod <- logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification")

logit_wflow <- workflow() %>%
  add_recipe(cann_recipe) %>%
  add_model(logit_mod)

logit_fit <- logit_wflow %>%
  fit_resamples(cann_cvs)


→ A | warning: ! Logistic regression is intended for modeling binary outcomes, but there are 3 levels in the outcome.
               ℹ If this is unintended, adjust outcome levels accordingly or see the `multinom_reg()` function.
→ B | warning: prediction from rank-deficient fit; attr(*, "non-estim") has doubtful cases
→ C | error:   Failed to compute `roc_auc()`. 

What happened?


Problem 1: There are three categories in Type. How do we interpret the log-odds for multiple groups?


Problem 2: The model is trying to fit 65 predictor coefficients! That’s a LOT.
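As the first warning suggests, a multinomial model handles a three-level outcome directly. A sketch of the swap, reusing the same recipe and resamples:

```r
# multinom_reg() fits outcomes with 3+ levels;
# "nnet" is the default parsnip engine for it
multinom_mod <- multinom_reg() %>%
  set_engine("nnet") %>%
  set_mode("classification")

multinom_wflow <- workflow() %>%
  add_recipe(cann_recipe) %>%
  add_model(multinom_mod)

multinom_fit <- multinom_wflow %>%
  fit_resamples(cann_cvs)
```

This fixes Problem 1; the large number of predictors (Problem 2) is still there.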

Discriminant Analysis

lda_mod <- discrim_linear() %>%
  set_engine("MASS") %>%
  set_mode("classification")

lda_wflow <- workflow() %>%
  add_recipe(cann_recipe) %>%
  add_model(lda_mod)

lda_fit <- lda_wflow %>%
  fit_resamples(cann_cvs)


→ A | warning: variables are collinear
→ B | warning: no non-missing arguments to min; returning Inf
→ C | error:   variables 15 16 appear to be constant within groups

What happened now?


Problem 1: There are still 65 predictors, i.e., 65 dimensions!

Problem 2: Some of these predictors contain duplicate information.

dplyr::select(cann_clean, Dry, Mouth) %>%
  arrange(desc(Dry))
# A tibble: 2,305 × 2
     Dry Mouth
   <dbl> <dbl>
 1     1     1
 2     0     0
 3     0     0
 4     0     0
 5     0     0
 6     0     0
 7     0     0
 8     0     0
 9     0     0
10     0     0
# ℹ 2,295 more rows
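Duplicate columns like Dry/Mouth, and any constant columns, can be dropped in the recipe itself. A sketch using the filter steps from the recipes package:

```r
# step_zv() removes zero-variance (constant) predictors;
# step_corr() removes predictors highly correlated with others --
# perfectly duplicated columns included
cann_recipe_lda <- recipe(Type ~ ., data = cann_clean) %>%
  step_rm(Strain, Effects, Flavor) %>%
  step_zv(all_predictors()) %>%
  step_corr(all_numeric_predictors(), threshold = 0.9)
```

With these steps in place, the "constant within groups" and collinearity complaints from LDA should go away.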

KNN

knn_mod <- nearest_neighbor(neighbors = 5) %>%
  set_engine("kknn") %>%
  set_mode("classification")

knn_wflow <- workflow() %>%
  add_recipe(cann_recipe) %>%
  add_model(knn_mod)

knn_fit <- knn_wflow %>%
  fit_resamples(cann_cvs)



No errors!!!!

How’d we do?

knn_fit <- knn_wflow %>%
  fit_resamples(cann_cvs,
                metrics = metric_set(accuracy, roc_auc, precision, recall)
                )


knn_fit %>% collect_metrics()
# A tibble: 4 × 6
  .metric   .estimator  mean     n std_err .config             
  <chr>     <chr>      <dbl> <int>   <dbl> <chr>               
1 accuracy  multiclass 0.496     5 0.00971 Preprocessor1_Model1
2 precision macro      0.464     5 0.0110  Preprocessor1_Model1
3 recall    macro      0.464     5 0.0102  Preprocessor1_Model1
4 roc_auc   hand_till  0.678     5 0.0107  Preprocessor1_Model1

Woof!

Decision Trees

Let’s play 20 questions

Is this strain described as “Energetic”?

Let’s play 20 questions.

Is this strain described as tasting like “Pineapple”?

Akinator says…

An image of a genie putting their hands to their head and contemplating what to guess.

An image of a genie guessing the type of marijuana as sativa.

Declaring a Model

tree_mod <- decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification")


Specifying a Workflow

tree_wflow <- workflow() %>%
  add_model(tree_mod) %>% 
  add_recipe(cann_recipe)

Decision Trees

tree_fit <- tree_wflow %>%
  fit_resamples(cann_cvs,
                metrics = metric_set(accuracy, roc_auc, precision, recall)
                )


tree_fit %>% collect_metrics()
# A tibble: 4 × 6
  .metric   .estimator  mean     n std_err .config             
  <chr>     <chr>      <dbl> <int>   <dbl> <chr>               
1 accuracy  multiclass 0.622     5 0.00487 Preprocessor1_Model1
2 precision macro      0.602     5 0.00647 Preprocessor1_Model1
3 recall    macro      0.571     5 0.0107  Preprocessor1_Model1
4 roc_auc   hand_till  0.754     5 0.00250 Preprocessor1_Model1

Inspecting the Fit

tree_fit_1 <- tree_wflow %>%
  fit(cann_clean)

tree_fit_1$fit
$actions
$actions$model
$spec
Decision Tree Model Specification (classification)

Computational engine: rpart 


$formula
NULL

attr(,"class")
[1] "action_model" "action_fit"   "action"      


$fit
parsnip model object

n= 2305 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 2305 1118 hybrid (0.51496746 0.29804772 0.18698482)  
   2) Sleepy< 0.5 1580  629 hybrid (0.60189873 0.14303797 0.25506329)  
     4) Energetic< 0.5 981  344 hybrid (0.64933741 0.20591233 0.14475025) *
     5) Energetic>=0.5 599  285 hybrid (0.52420701 0.04006678 0.43572621)  
      10) Relaxed>=0.5 281  104 hybrid (0.62989324 0.04982206 0.32028470) *
      11) Relaxed< 0.5 318  147 sativa (0.43081761 0.03144654 0.53773585) *
   3) Sleepy>=0.5 725  264 indica (0.32551724 0.63586207 0.03862069) *

attr(,"class")
[1] "stage_fit" "stage"    

Decision Trees

tree_fitted <- tree_fit_1 %>% 
  extract_fit_parsnip()

rpart.plot(tree_fitted$fit)

Note the rpart.plot() function lives in the rpart.plot package!

What might we change?

Hyperparameters!

Tree Depth

tree_depth: How many levels of splits will we “allow” the tree to make?

  • If we allowed infinite splits, we’d end up with only one observation in each “leaf”. This would badly overfit the training data!
  • If we allow only one split, our accuracy won’t be that great.
  • Default in rpart: Up to 30

Minimum Observations

min_n: How many observations have to be in a “leaf” for us to be allowed to split it further?

  • If min_n is too small, we’re overfitting.
  • If min_n is too big, we’re not allowing enough flexibility.
  • Default in rpart: 20

Tuning min_n

Let’s try varying the minimum number of observations in a leaf between 2 and 20.

tree_grid <- grid_regular(min_n(c(2,20)),
                          levels = 4)

tree_grid
# A tibble: 4 × 1
  min_n
  <int>
1     2
2     8
3    14
4    20

Start wide!

Tuning with cross-validation takes a long time! Do yourself a favor and start with a small but wide grid.

Setting up the tuning

tree_mod <- decision_tree(min_n = tune()) %>%
  set_engine("rpart") %>%
  set_mode("classification")

tree_wflow <- workflow() %>%
  add_recipe(cann_recipe) %>%
  add_model(tree_mod)

tree_grid_search <-
  tune_grid(
    tree_wflow,
    resamples = cann_cvs,
    grid = tree_grid
  )

tuning_metrics <- tree_grid_search %>% collect_metrics()

Inspecting the fits

tuning_metrics
# A tibble: 8 × 7
  min_n .metric  .estimator  mean     n std_err .config             
  <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
1     2 accuracy multiclass 0.622     5 0.00487 Preprocessor1_Model1
2     2 roc_auc  hand_till  0.754     5 0.00250 Preprocessor1_Model1
3     8 accuracy multiclass 0.622     5 0.00487 Preprocessor1_Model2
4     8 roc_auc  hand_till  0.754     5 0.00250 Preprocessor1_Model2
5    14 accuracy multiclass 0.622     5 0.00487 Preprocessor1_Model3
6    14 roc_auc  hand_till  0.754     5 0.00250 Preprocessor1_Model3
7    20 accuracy multiclass 0.622     5 0.00487 Preprocessor1_Model4
8    20 roc_auc  hand_till  0.754     5 0.00250 Preprocessor1_Model4

What’s the best choice of min_n?
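Since every value of min_n gave identical cross-validated metrics here, the helpers from the tune package confirm there is no meaningful difference. A sketch:

```r
# Rank the candidate values by cross-validated accuracy...
tree_grid_search %>% show_best(metric = "accuracy")

# ...and pull out the top row as a one-row tibble of hyperparameters
best_min_n <- tree_grid_search %>% select_best(metric = "accuracy")
```

When the metrics are tied like this, any value in the grid is equally defensible.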

What else can we change?

How is rpart choosing to stop splitting?

cost complexity = how much metric gain is “worth it” to do another split?

  • Default in rpart: cp = 0.01, i.e., a split must improve the overall fit by at least 1%.

Cost complexity

tree_grid <- grid_regular(cost_complexity(),
                          tree_depth(),
                          min_n(), 
                          levels = 2)

tree_grid
# A tibble: 8 × 3
  cost_complexity tree_depth min_n
            <dbl>      <int> <int>
1    0.0000000001          1     2
2    0.1                   1     2
3    0.0000000001         15     2
4    0.1                  15     2
5    0.0000000001          1    40
6    0.1                   1    40
7    0.0000000001         15    40
8    0.1                  15    40

Cost Complexity

tree_mod <- decision_tree(cost_complexity = tune(),
                          tree_depth = tune(),
                          min_n = tune()) %>%
  set_engine("rpart") %>%
  set_mode("classification")

tree_wflow <- workflow() %>%
  add_model(tree_mod) %>% 
  add_recipe(cann_recipe)

tree_grid_search <-
  tune_grid(
    tree_wflow,
    resamples = cann_cvs,
    grid = tree_grid
  )

tuning_metrics <- tree_grid_search %>% collect_metrics()

Tuning

tuning_metrics
# A tibble: 16 × 9
   cost_complexity tree_depth min_n .metric  .estimator  mean     n std_err
             <dbl>      <int> <int> <chr>    <chr>      <dbl> <int>   <dbl>
 1    0.0000000001          1     2 accuracy multiclass 0.613     5 0.00708
 2    0.0000000001          1     2 roc_auc  hand_till  0.680     5 0.00530
 3    0.1                   1     2 accuracy multiclass 0.613     5 0.00708
 4    0.1                   1     2 roc_auc  hand_till  0.680     5 0.00530
 5    0.0000000001         15     2 accuracy multiclass 0.561     5 0.00525
 6    0.0000000001         15     2 roc_auc  hand_till  0.640     5 0.00677
 7    0.1                  15     2 accuracy multiclass 0.613     5 0.00708
 8    0.1                  15     2 roc_auc  hand_till  0.680     5 0.00530
 9    0.0000000001          1    40 accuracy multiclass 0.613     5 0.00708
10    0.0000000001          1    40 roc_auc  hand_till  0.680     5 0.00530
11    0.1                   1    40 accuracy multiclass 0.613     5 0.00708
12    0.1                   1    40 roc_auc  hand_till  0.680     5 0.00530
13    0.0000000001         15    40 accuracy multiclass 0.589     5 0.00432
14    0.0000000001         15    40 roc_auc  hand_till  0.753     5 0.00406
15    0.1                  15    40 accuracy multiclass 0.613     5 0.00708
16    0.1                  15    40 roc_auc  hand_till  0.680     5 0.00530
# ℹ 1 more variable: .config <chr>

Tuning

tuning_metrics %>%
  filter(.metric == "accuracy") %>%
  slice_max(mean)
# A tibble: 6 × 9
  cost_complexity tree_depth min_n .metric  .estimator  mean     n std_err
            <dbl>      <int> <int> <chr>    <chr>      <dbl> <int>   <dbl>
1    0.0000000001          1     2 accuracy multiclass 0.613     5 0.00708
2    0.1                   1     2 accuracy multiclass 0.613     5 0.00708
3    0.1                  15     2 accuracy multiclass 0.613     5 0.00708
4    0.0000000001          1    40 accuracy multiclass 0.613     5 0.00708
5    0.1                   1    40 accuracy multiclass 0.613     5 0.00708
6    0.1                  15    40 accuracy multiclass 0.613     5 0.00708
# ℹ 1 more variable: .config <chr>


tuning_metrics %>%
  filter(.metric == "roc_auc") %>%
  slice_max(mean)
# A tibble: 1 × 9
  cost_complexity tree_depth min_n .metric .estimator  mean     n std_err
            <dbl>      <int> <int> <chr>   <chr>      <dbl> <int>   <dbl>
1    0.0000000001         15    40 roc_auc hand_till  0.753     5 0.00406
# ℹ 1 more variable: .config <chr>

Try it!

Open Activity-Decision-Tree

  1. Fit a final model with the selected hyperparameters

  2. Report some metrics for the final model

  3. Plot the tree (look at the previously provided code)

  4. Interpret the first two levels of splits in plain English.
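A sketch of steps 1–3, using tune’s finalize helpers with the roc_auc winner (swap in whichever metric you prefer):

```r
# Plug the winning hyperparameters back into the workflow...
best_params <- tree_grid_search %>% select_best(metric = "roc_auc")

final_wflow <- tree_wflow %>%
  finalize_workflow(best_params)

# ...then fit once on the full cleaned data and plot the tree
final_fit <- final_wflow %>% fit(cann_clean)

final_fit %>%
  extract_fit_engine() %>%
  rpart.plot()
```

extract_fit_engine() returns the underlying rpart object, which is what rpart.plot() expects.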